import openai
import google.generativeai as genai
import torch
import asyncio
from transformers import AutoTokenizer, AutoModelForCausalLM

class APIClient:
    """Route text-generation prompts to the OpenAI API or to local Hugging Face models.

    ``cfg`` is a mapping; this class reads the following keys:

    * ``use_gpt_api`` / ``use_gpt_api_for_graph`` -- routing flags used by :meth:`call`.
    * ``use_gemini_api`` / ``use_huggingface_login`` -- enable credential setup in
      ``__init__``.
    * ``api_key`` -- mapping with ``openai`` / ``google`` / ``huggingface_hub`` entries.
    * ``models`` -- mapping of model key to a dict with ``type`` (``"local"`` selects a
      Hugging Face model), ``name`` (passed to ``AutoModelForCausalLM.from_pretrained``)
      and ``tokenizer`` (passed to ``AutoTokenizer.from_pretrained``).

    Local models are loaded lazily and cached per key in ``self.models``.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # model_key -> (model, tokenizer); populated lazily by load_local_model.
        self.models = {}

        if cfg.get("use_gpt_api"):
            # Sets the key on the openai module itself (process-wide state).
            openai.api_key = cfg["api_key"]["openai"]

        if cfg.get("use_gemini_api"):
            # NOTE(review): Gemini is configured here, but `call` has no Gemini route,
            # so such requests currently end in NotImplementedError -- confirm intent.
            genai.configure(api_key=cfg["api_key"]["google"])

        if cfg.get("use_huggingface_login"):
            # Imported lazily so huggingface_hub is only required when login is enabled.
            from huggingface_hub import login
            login(cfg["api_key"]["huggingface_hub"])

    def load_local_model(self, model_key):
        """Return the cached ``(model, tokenizer)`` pair for *model_key*, loading it on first use.

        Weights are loaded in bfloat16 and moved to ``self.device``.

        Raises:
            KeyError: *model_key* (or one of its ``name``/``tokenizer`` entries) is
                missing from ``cfg["models"]``.
        """
        if model_key in self.models:
            return self.models[model_key]

        model_config = self.cfg["models"][model_key]
        tokenizer = AutoTokenizer.from_pretrained(model_config["tokenizer"], use_fast=True)
        model = AutoModelForCausalLM.from_pretrained(model_config["name"], torch_dtype=torch.bfloat16)
        model.to(self.device)
        self.models[model_key] = (model, tokenizer)
        return self.models[model_key]

    def _use_openai(self, for_graph):
        """Return True when a request should be sent to the OpenAI API.

        Requires ``cfg["use_gpt_api"]``. Graph requests (truthy *for_graph*) also
        require ``cfg["use_gpt_api_for_graph"]``.
        """
        if not self.cfg.get("use_gpt_api"):
            return False
        return not for_graph or bool(self.cfg.get("use_gpt_api_for_graph"))

    async def call(self, model_key, prompt, temp=0.5, max_tokens=400, for_graph=False):
        """Generate a completion for *prompt* and return it as a stripped string.

        The OpenAI API is used when :meth:`_use_openai` allows it; *model_key* is then
        passed as the OpenAI model name. Otherwise the request goes to a local model
        when ``cfg["models"][model_key]["type"] == "local"``.

        Raises:
            NotImplementedError: no enabled backend handles *model_key*.
            KeyError: OpenAI is not used and *model_key* (or its ``type`` entry) is
                missing from ``cfg["models"]``.
        """
        if self._use_openai(for_graph):
            return await self.call_openai(prompt, model_key, temp, max_tokens)
        if self.cfg["models"][model_key]["type"] == "local":
            return await self.call_local(model_key, prompt, temp, max_tokens)
        raise NotImplementedError("Unsupported model type or API flag")

    async def call_openai(self, prompt, model_name, temp, max_tokens):
        """Send *prompt* as a single user message to *model_name*; return the stripped reply.

        NOTE(review): ``openai.ChatCompletion`` is the legacy (<1.0) openai SDK interface;
        confirm the pinned openai version before upgrading.
        """
        # The SDK call is blocking; run it in a worker thread to keep the event loop free.
        response = await asyncio.to_thread(
            openai.ChatCompletion.create,
            model=model_name,
            temperature=temp,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens
        )
        return response['choices'][0]['message']['content'].strip()

    @staticmethod
    def _generation_kwargs(temp, max_tokens):
        """Build the sampling-related keyword arguments for ``model.generate``.

        ``generate`` only applies ``temperature`` when ``do_sample=True``. A positive
        *temp* therefore enables sampling at that temperature, and ``0``/``None`` (or a
        negative value) selects greedy decoding.
        """
        kwargs = {"max_new_tokens": max_tokens, "use_cache": True}
        if temp is not None and temp > 0:
            kwargs["do_sample"] = True
            kwargs["temperature"] = temp
        else:
            kwargs["do_sample"] = False
        return kwargs

    async def call_local(self, model_key, prompt, temp, max_tokens):
        """Generate with the local model for *model_key*; return only the new text, stripped.

        The prompt tokens are removed before decoding, so the result has the same
        completion-only shape as :meth:`call_openai`.
        """
        model, tokenizer = self.load_local_model(model_key)
        if tokenizer.pad_token is None:
            # Tokenizer has no pad token; reuse EOS so padding-aware code paths work.
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = tokenizer.pad_token_id

        # Calling the tokenizer (instead of .encode) also yields an attention_mask;
        # generate cannot reliably infer one when pad == eos, as may be set above.
        inputs = tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = inputs["input_ids"]
        prompt_len = input_ids.shape[-1]

        # generate is blocking; run it in a worker thread to keep the event loop free.
        outputs = await asyncio.to_thread(
            model.generate,
            input_ids,
            attention_mask=inputs.get("attention_mask"),
            pad_token_id=tokenizer.pad_token_id,
            **self._generation_kwargs(temp, max_tokens),
        )
        # The returned sequence starts with the prompt tokens; decode only the continuation.
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
